
Conversation

pggPL (Collaborator) commented Dec 10, 2025

Description

Adds nvte_grouped_gemm API using cuBLASLt grouped matmul for batched GEMM on tensors with varying shapes. A GPU kernel (setup_grouped_gemm_kernel) converts NVTEGroupedTensor format (contiguous buffer + offsets) to cuBLAS requirements (pointer arrays + per-matrix M/N/K).

New API

void nvte_grouped_gemm(int transa, int transb, const NVTETensor alpha, const NVTEGroupedTensor A,
                       const NVTEGroupedTensor B, const NVTETensor beta, const NVTEGroupedTensor C,
                       NVTEGroupedTensor D, NVTETensor workspace_setup, NVTETensor workspace_cublas,
                       NVTEGroupedMatmulConfig config, cudaStream_t stream);

Computes D = alpha * op(A) @ op(B) + beta * C for groups of matrices with potentially different shapes.

Type of change

  • New feature (non-breaking change which adds functionality)

Changes

  • GPU setup kernel computing pointers/dims from grouped tensor metadata
  • FP8 support with scale_inv handling and TN layout selection on Hopper
  • GroupedGemmSetupWorkspace struct for cuBLAS workspace layout
  • Tests in test_grouped_gemm.cu comparing against nvte_multi_tensor_gemm (FP8/BF16, various shapes and transpose layouts)

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Pawel Gadzinski <[email protected]>
pggPL changed the title from "[common] Add support for cublasLt GEMM for GroupedTensor" to "[common] Add support for cuBLASLt GEMM for GroupedTensor" on Dec 10, 2025
pre-commit-ci bot and others added 3 commits December 10, 2025 14:32
- Add FP8 scale_inv pointer handling in nvte_grouped_gemm for proper FP8 GEMM
- Fix random padding in tests to ensure 16-byte alignment for all dtypes
- Reorder GroupedGemmSetupWorkspace members for natural alignment
- Remove debug prints

Signed-off-by: Pawel Gadzinski <[email protected]>
ptrendx added the MoE label Dec 10, 2025
ptrendx linked an issue Dec 10, 2025 that may be closed by this pull request
pggPL and others added 2 commits December 10, 2025 22:34
pggPL (Collaborator, Author) commented Dec 10, 2025

/te-ci L0

pggPL marked this pull request as ready for review December 10, 2025 21:43
greptile-apps bot (Contributor) commented Dec 10, 2025

Greptile Summary

Adds nvte_grouped_gemm API leveraging cuBLASLt grouped matmul for efficient batched GEMM on tensors with varying shapes. The implementation includes:

  • Core functionality: GPU kernel (setup_grouped_gemm_kernel) bridges NVTEGroupedTensor format (contiguous buffer + offsets) to cuBLAS requirements (pointer arrays + per-matrix dimensions)
  • FP8 support: Handles scale_inv pointers and enforces TN layout selection on Hopper GPUs via operand selection logic
  • Configuration: GroupedMatmulConfig with optional avg_m/avg_n/avg_k hints for algorithm heuristics
  • Testing: Comprehensive tests compare against nvte_multi_tensor_gemm baseline across FP8/BF16, various shapes, and transpose configurations
  • Test infrastructure: New build_grouped_tensor helper with RAII memory management and random padding for alignment testing

The implementation follows established patterns from the codebase, includes proper validation, and has thorough test coverage.

Confidence Score: 4/5

  • This PR is safe to merge with minor considerations
  • Score reflects solid implementation with comprehensive testing, but has one logical issue in config.h attr_sizes that needs addressing. The code follows established patterns, includes proper validation and error handling, and has extensive test coverage comparing against baseline implementation. The const_cast usage and non-deterministic test padding are acceptable design choices.
  • Pay attention to transformer_engine/common/gemm/config.h for the attr_sizes calculation issue with std::optional

Important Files Changed

Filename - Overview
transformer_engine/common/gemm/cublaslt_grouped_gemm.cu - New cuBLASLt grouped GEMM implementation with FP8 support; includes GPU setup kernel, operand selection, and workspace management. Well-structured with comprehensive validation.
transformer_engine/common/include/transformer_engine/gemm.h - API header additions for grouped GEMM: new config types, attributes, and main nvte_grouped_gemm function declaration with clear documentation.
tests/cpp/operator/test_grouped_gemm.cu - Comprehensive test suite comparing grouped GEMM against multi-tensor GEMM baseline. Tests FP8/BF16, various shapes, transpose combinations, and null C case.
tests/cpp/test_common.cu - Adds build_grouped_tensor helper with RAII memory management. Handles uniform/varying dimensions, computes offsets, includes random padding for alignment testing.

Sequence Diagram

sequenceDiagram
    participant User
    participant nvte_grouped_gemm
    participant Validation
    participant Operand Selection
    participant Setup Kernel
    participant cuBLASLt

    User->>nvte_grouped_gemm: Call with A, B, C, D, alpha, beta
    nvte_grouped_gemm->>Validation: Check SM >= 100 (Blackwell)
    nvte_grouped_gemm->>Validation: validate_grouped_gemm_inputs()
    Validation-->>nvte_grouped_gemm: OK
    
    nvte_grouped_gemm->>Operand Selection: select_grouped_operand(A, transa)
    Operand Selection->>Operand Selection: Check FP8 TN layout requirements
    Operand Selection->>Operand Selection: Choose row-wise vs column-wise data
    Operand Selection-->>nvte_grouped_gemm: A_sel (dptr, dtype, trans, use_columnwise)
    
    nvte_grouped_gemm->>Operand Selection: select_grouped_operand(B, transb)
    Operand Selection-->>nvte_grouped_gemm: B_sel (dptr, dtype, trans, use_columnwise)
    
    nvte_grouped_gemm->>Setup Kernel: Allocate setup workspace
    nvte_grouped_gemm->>Setup Kernel: launch_grouped_gemm_setup()
    Setup Kernel->>Setup Kernel: setup_grouped_gemm_kernel<<<blocks, threads>>>
    Note over Setup Kernel: Per-tensor computation:<br/>- Compute A/B/C/D pointers from offsets<br/>- Compute M/N/K from dimensions<br/>- Fill alpha_ptrs, beta_ptrs arrays
    Setup Kernel-->>nvte_grouped_gemm: Workspace populated
    
    nvte_grouped_gemm->>cuBLASLt: init_matrix_layouts(descA, descB, descC, descD)
    nvte_grouped_gemm->>cuBLASLt: init_matmul_desc(op_A, op_B)
    nvte_grouped_gemm->>cuBLASLt: set_fp8_scale_pointers() if FP8
    nvte_grouped_gemm->>cuBLASLt: select_grouped_gemm_algo() with avg hints
    cuBLASLt-->>nvte_grouped_gemm: Algorithm selected
    
    nvte_grouped_gemm->>cuBLASLt: cublasLtMatmul()
    Note over cuBLASLt: Execute grouped GEMM:<br/>D[i] = alpha[i] * op(A[i]) @ op(B[i]) + beta[i] * C[i]
    cuBLASLt-->>nvte_grouped_gemm: GEMM complete
    
    nvte_grouped_gemm-->>User: Return

greptile-apps bot (Contributor) left a comment

Additional Comments (4)

  1. tests/cpp/operator/test_grouped_gemm.cu, line 75 (link)

    logic: missing columnwise_data in move assignment

  2. tests/cpp/operator/test_grouped_gemm.cu, line 336-337 (link)

    logic: kSameFirst and kSameLast test cases use identical shape tuples. Should kSameFirst and kSameLast have different shape patterns to properly test the respective scenarios?

  3. tests/cpp/operator/test_grouped_gemm.cu, line 367-378 (link)

    logic: missing case for InputCase::kFP8Delayed

  4. transformer_engine/common/include/transformer_engine/gemm.h, line 265-266 (link)

    style: The avg_m, avg_n, avg_k parameters are not documented in the function comment

    What do these average dimension parameters represent and how should they be computed?

    Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

4 files reviewed, 4 comments


ptrendx requested a review from timmoon10 December 10, 2025 22:35
pggPL and others added 2 commits December 11, 2025 11:56
- Remove unused alignment parameter from GroupedGemmSetupWorkspace::from_buffers
- Simplify select_grouped_operand by removing dead code branches
- Add GroupedOperandSelection.tensor field to avoid passing tensor separately
- Extract set_fp8_scale_pointers and init_matrix_layouts helpers
- Add safety check for FP8 on Hopper column-wise fallback
- Support NULL C tensor when beta=0 (uses D as placeholder)
- Remove unused get_scale_inv() from test
- Add use_null_c test parameter and test case
- Fix documentation: alpha/beta are single element tensors only

Signed-off-by: Piotr Gadzinski <[email protected]>
Signed-off-by: Pawel Gadzinski <[email protected]>
greptile-apps bot (Contributor) left a comment

Additional Comments (1)

  1. transformer_engine/common/include/transformer_engine/gemm.h, line 266 (link)

    syntax: Documentation incorrectly states alpha[i] and beta[i]. The implementation uses a single alpha and beta value for all matrices in the group (batch stride is 1 in init_matmul_desc at cublaslt_gemm.cu:1404), not per-matrix scaling.

4 files reviewed, 1 comment


pggPL and others added 3 commits December 11, 2025 12:16
- Change alpha/beta from single values to per-matrix arrays
- Validate alpha/beta have exactly num_tensors elements
- Update kernel to index alpha_ptr[idx] and beta_ptr[idx]
- Move alpha/beta validation to validate_grouped_gemm_inputs
- Update tests to use per-matrix alpha/beta arrays
- Update documentation

Signed-off-by: Piotr Gadzinski <[email protected]>
Signed-off-by: Pawel Gadzinski <[email protected]>
Signed-off-by: Piotr Gadzinski <[email protected]>
Signed-off-by: Pawel Gadzinski <[email protected]>
pggPL (Collaborator, Author) commented Dec 11, 2025

/te-ci

greptile-apps[bot]: This comment was marked as resolved.

Comment on lines 273 to 274
NVTEMatmulConfig config, cudaStream_t stream, const int64_t *avg_m,
const int64_t *avg_n, const int64_t *avg_k);
Collaborator

The average sizes seem like advanced configs that would be better to leave out of the top-level API. Can we move them inside NVTEMatmulConfig?

Collaborator Author

It was a suggestion of @ptrendx - customers may want to use them if they know something more about the shapes. For example, if there are multiple tensors with K dimension D and one tensor with K dimension equal to 1, then telling cuBLAS that the average K is D may result in better performance.
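A quick numeric illustration of that point (plain arithmetic, not TE code): with group K dims {D, D, D, 1} and D = 1024, the computed mean understates the K that dominates the work, so a user-injected hint of 1024 can serve the heuristic better.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Integer mean of per-group K dims; with one tiny outlier the mean drops
// below the K value that actually dominates the FLOPs. Illustrative only.
int64_t mean_k(const std::vector<int64_t>& ks) {
  int64_t sum = 0;
  for (int64_t k : ks) sum += k;
  return sum / static_cast<int64_t>(ks.size());
}
```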

timmoon10 (Collaborator) commented Dec 22, 2025

We should definitely expose these options, but it would be better to put them in NVTEMatmulConfig rather than the top-level function signature. If you look at nvte_cublas_gemm:

  void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, const NVTETensor bias,
                        NVTETensor pre_gelu_out, bool transa, bool transb, bool grad,
                        NVTETensor workspace, bool accumulate, bool use_split_accumulator,
                        int math_sm_count, cudaStream_t stream);

You see advanced options like pre_gelu_out and use_split_accumulator that are only needed for specialized cases. It's even worse for use_split_accumulator, since that's only relevant for FP8 on Hopper and now it's sitting uselessly in the API forever. Compare with the v2 API, which is much closer to the original GEMM API:
  void nvte_cublas_gemm_v2(int transa, int transb, const float *alpha, const NVTETensor A,
                           const NVTETensor B, const float *beta, const NVTETensor C, NVTETensor D,
                           NVTETensor workspace, NVTEMatmulConfig config, cudaStream_t stream);

The idea of NVTEMatmulConfig is to hold these advanced, kernel-specific options in a way where we can add or deprecate them easily without breaking API changes.

Collaborator Author

I added an analogous NVTEGroupedMatmulConfig.

Collaborator Author

It's worth noting that cuBLAS does not support pre_gelu_out or use_split_accumulator for grouped GEMM.

greptile-apps bot (Contributor) left a comment

Additional Comments (2)

  1. transformer_engine/common/include/transformer_engine/gemm.h, line 256 (link)

    syntax: Documentation mentions non-existent config parameter - this parameter is not in the function signature on line 276-280

  2. tests/cpp/operator/test_grouped_gemm.cu, line 141 (link)

    logic: Alignment calculation appears incorrect - multiplying 16 bytes by 8 bits gives 128 bits, not the intended 16-byte alignment
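For reference, the intended byte-level round-up to a 16-byte boundary is the usual power-of-two trick (a generic sketch under the assumption that the test wants byte alignment, not bits; not the test file's actual code):

```cpp
#include <cassert>
#include <cstddef>

// Round a byte offset up to the next multiple of a power-of-two alignment.
constexpr size_t align_up(size_t offset, size_t alignment) {
  return (offset + alignment - 1) & ~(alignment - 1);
}
```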

6 files reviewed, 2 comments


Signed-off-by: Pawel Gadzinski <[email protected]>
pggPL (Collaborator, Author) commented Dec 19, 2025

/te-ci L0

greptile-apps bot (Contributor) left a comment

Additional Comments (4)

  1. transformer_engine/common/gemm/cublaslt_grouped_gemm.cu, line 445 (link)

    style: Use const_cast sparingly - casting away const from a_base could be risky if the caller expects the data to remain unchanged

  2. transformer_engine/common/gemm/cublaslt_grouped_gemm.cu, line 451-455 (link)

    style: The comment mentions "Test stores A" but this is implementation code, not test code - update comment for clarity

  3. tests/cpp/operator/test_grouped_gemm.cu, line 282-283 (link)

    style: Commented code should be removed before merging

  4. transformer_engine/common/gemm/cublaslt_grouped_gemm.cu, line 313-318 (link)

    style: Variable naming could be clearer - rowa/cola/rowb/colb could be confused with actual row/column counts vs storage dimensions

    Consider renaming to lda_rows/lda_cols etc. or adding clarifying comments

8 files reviewed, 4 comments


Signed-off-by: Pawel Gadzinski <[email protected]>
Signed-off-by: Pawel Gadzinski <[email protected]>
inline int64_t compute_avg_first_dim(const transformer_engine::GroupedTensor *t) {
  // logical_shape[0] is either num_tensors*M (uniform) or sum_of_M (varying first)
  // In both cases, dividing by num_tensors gives the average
  return static_cast<int64_t>(t->logical_shape.data[0]) / static_cast<int64_t>(t->num_tensors);
}
Collaborator

What is logical_shape.data[0]? Do we have access to this field if we want CUDA graphs (i.e. without a D2H copy)?

Collaborator Author

logical_shape represents the total shape of all tensors in the grouped tensor. For example, we can have a grouped tensor of logical shape [10, 10] with tensors of shape [3, 10] and [7, 10] inside it. The logical shape must be constant to use CUDA graphs.
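That averaging convention can be shown with a standalone sketch (an illustrative function, not the TE helper itself):

```cpp
#include <cassert>
#include <cstdint>

// logical_shape[0] is either num_tensors * M (uniform first dim) or the sum of
// per-tensor first dims (varying first dim); in both cases integer division by
// num_tensors yields the average first dimension.
int64_t avg_first_dim(int64_t logical_first_dim, int64_t num_tensors) {
  return logical_first_dim / num_tensors;
}
```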

zhongbozhu (Collaborator) commented Jan 7, 2026

I see - this could be an issue for future CUDA graph support of dynamic-shape MoE. Basically, how it works is: for an input buffer of shape [10, 10], the [10, 10] shape is static, but there will be an on-device tensor that says the splits are [[3, 10], [3, 10]], and the actual valid shape will be [6, 10]. The excess shape would be wasted GEMM performance (?)

That being said, this function is simply for getting a cuBLAS heuristic, so it's nothing fundamental. Can we support injecting static values directly here? Assuming users know their model well, they can manually set those values and avoid needing to know these shapes on the host, since that would mean a D2H copy and therefore no CUDA graph.

Collaborator Author

The only problem would be that wrong statistics are computed - not any CUDA errors or incorrect memory reads, am I right?

There is support for custom statistic injection here.

Collaborator

Wrong statistics can be avoided if we zero-pad, but that's another story - need to check with the cuBLAS folks.

Btw, does this PR have BF16 grouped GEMM? Is the grouped GEMM optimization SM100-only for now?

Collaborator Author

It supports BF16, and there are plans to support other devices as well.

pggPL and others added 2 commits January 7, 2026 16:22
greptile-apps bot (Contributor) left a comment

Additional Comments (4)

  1. transformer_engine/common/gemm/cublaslt_grouped_gemm.cu, line 449-453 (link)

    style: const_cast removes const correctness. While the comment says cuBLAS won't modify the data, this breaks type safety guarantees.

    Consider if cuBLASLt API could accept const void** for input matrices, or document why this is necessary.

  2. tests/cpp/operator/test_grouped_gemm.cu, line 136-145 (link)

    style: Random padding for alignment testing is good, but the random seed is hardcoded (12345).

    Verify tests are deterministic and document why this specific padding strategy was chosen.

  3. transformer_engine/common/gemm/cublaslt_grouped_gemm.cu, line 582-585 (link)

    style: Fast accumulator mode is always enabled for FP8, but split accumulator isn't supported by cuBLASLt grouped GEMM API.

    Check that this accuracy trade-off is documented in the public API so users are aware.

  4. transformer_engine/common/gemm/cublaslt_grouped_gemm.cu, line 514 (link)

    style: SM100 (Blackwell) requirement is very restrictive.

    Verify this is a hard requirement from cuBLASLt grouped GEMM API and document the limitation clearly for users.

8 files reviewed, 4 comments


pggPL and others added 4 commits January 7, 2026 16:34
Signed-off-by: Pawel Gadzinski <[email protected]>
Signed-off-by: Pawel Gadzinski <[email protected]>
Signed-off-by: Pawel Gadzinski <[email protected]>
greptile-apps bot (Contributor) left a comment

Additional Comments (3)

  1. transformer_engine/common/gemm/config.h, line 47-48 (link)

    logic: attr_sizes calculation for optional fields includes full std::optional size rather than just int64_t

    when nvte_get_grouped_matmul_config_attribute copies value to buffer, it only copies sizeof(int64_t) (line 150-160 in config.cpp), but attr_sizes declares sizeof(avg_m) which is sizeof(std::optional<int64_t>) (typically 16 bytes, not 8)

  2. transformer_engine/common/gemm/cublaslt_grouped_gemm.cu, line 449 (link)

    style: const_cast removes const qualifier from input pointers A, B, C

    while comment notes this is safe because cuBLAS won't modify the data, this relies on cuBLAS internal behavior and could break if cuBLAS implementation changes - consider verifying this is documented in cuBLAS API

  3. tests/cpp/test_common.cu, line 1093-1102 (link)

    style: random padding logic may generate non-deterministic test behavior

    while the seed is fixed (12345), random padding affects memory layout which could expose alignment issues inconsistently across test runs - consider documenting this is intentional for broader coverage

9 files reviewed, 3 comments


greptile-apps bot (Contributor) commented Jan 7, 2026

Additional Comments (5)

transformer_engine/common/gemm/config.h, line 48
attr_sizes incorrectly uses sizeof(avg_m) which is sizeof(std::optional<int64_t>) (typically 16 bytes), but the get/set functions treat it as sizeof(int64_t) (8 bytes)

  static constexpr size_t attr_sizes[] = {sizeof(int64_t), sizeof(int64_t), sizeof(int64_t),
                                          sizeof(int)};

transformer_engine/common/gemm/cublaslt_grouped_gemm.cu, line 514
SM100+ check inconsistent with documentation. Documentation states "Blackwell (SM100) or newer" but code checks >= 100. This allows SM100 (generic) but excludes SM100a/103a (actual Blackwell variants). The version-gated API requires CUBLAS_VERSION >= 130100, suggesting Blackwell-class hardware.

Check if >= 100 is intentional or should be checking specific Blackwell variants


transformer_engine/common/gemm/cublaslt_grouped_gemm.cu, line 314
Potential mismatched dimension logic for column-wise storage. When use_columnwise=true, rows_A and cols_A are computed as ws.M and ws.K respectively, but for the transposed case the dimensions appear flipped (line 313-314 vs 317-318).

Verify the dimension logic is correct for both row-wise and column-wise layouts with all transpose combinations


transformer_engine/common/gemm/cublaslt_grouped_gemm.cu, line 459
Comment states "Input A is stored as {K,M} when !transa" but the dimension extraction logic doesn't match the comment. When transa=false, M[idx] = a_last (line 457), which means a_last represents M. However, if A is stored as {K,M} with !transa, a_last should be the second dimension which is M. The logic appears correct, but the comment is confusing about storage order vs logical order


tests/cpp/operator/test_grouped_gemm.cu, line 94
Test comment mislabeled. Comment says "Same M (first dim)" but the shapes show M=64 (same), N varying (80→96→112), and K varying (32→48→64). This is NOT same M with varying N and K - it's same M only. The test actually tests same-M-varying-N-and-K which is valid, but the comment is misleading

// Number of streaming multiprocessors to use in GEMM kernel
int sm_count = 0;

static constexpr size_t attr_sizes[] = {sizeof(avg_m), sizeof(avg_n), sizeof(avg_k),
Member

Agree with greptile, after the change to optional this is wrong - you would need to do something like

sizeof(decltype(avg_m)::value_type)
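The pitfall can be demonstrated in isolation (ConfigSketch is a simplified stand-in for the real config struct, not TE code):

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

// sizeof(avg_m) measures the whole std::optional (payload + engaged flag +
// padding), while the attribute getter memcpys only the contained int64_t.
// Taking the size of value_type restores the intended 8-byte attribute size.
struct ConfigSketch {
  std::optional<int64_t> avg_m;
};

static_assert(sizeof(std::optional<int64_t>) > sizeof(int64_t),
              "the optional is strictly larger than its payload");
static_assert(sizeof(decltype(ConfigSketch::avg_m)::value_type) == sizeof(int64_t),
              "value_type recovers the 8-byte payload size");
```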


// Compute data pointers
// Note: const_cast is safe here - cuBLAS requires void** but won't modify A/B/C data
A_ptrs[idx] = const_cast<char *>(a_base) + a_offset * a_elem_size;
Member

You don't actually need this here - the fact that you got a const pointer to the GroupedTensor does not mean that the pointer inside it is a const pointer. So you can declare those a_base etc. pointers as regular char* and not have to const_cast them afterwards.
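A minimal illustration of that point (GroupedTensorSketch is a hypothetical stand-in, not the real GroupedTensor):

```cpp
#include <cassert>

// A const pointer to the struct makes the member pointer itself const
// (it cannot be reseated), but the data it points at stays non-const,
// so no const_cast is needed for pointer arithmetic on it.
struct GroupedTensorSketch {
  char* dptr;  // stands in for the grouped tensor's base data pointer
};

char* offset_base(const GroupedTensorSketch* t, long elem_offset, long elem_size) {
  char* base = t->dptr;  // t->dptr has type char* const here; copying it is fine
  return base + elem_offset * elem_size;
}
```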

pggPL and others added 3 commits January 9, 2026 17:27
Signed-off-by: Pawel Gadzinski <[email protected]>
Signed-off-by: Pawel Gadzinski <[email protected]>
greptile-apps bot (Contributor) left a comment

Greptile Overview

Greptile Summary

Overview

Adds comprehensive grouped GEMM support via cuBLASLt's grouped matmul API, enabling efficient batched matrix multiplication on tensors with varying shapes. The implementation bridges the gap between NVTEGroupedTensor's contiguous buffer + metadata format and cuBLASLt's pointer array requirements.

Key Features Implemented

  1. GPU Setup Kernel (setup_grouped_gemm_kernel): Computes per-tensor pointers and M/N/K dimensions from GroupedTensor metadata, handling both uniform and varying shapes with optional per-tensor offsets.

  2. FP8 Support: Correctly handles FP8 scale_inv pointers and enforces TN-layout-only constraint on Hopper via columnwise data fallback and transpose flag adjustment.

  3. Configuration System: New GroupedMatmulConfig struct with optional avg_m/avg_n/avg_k hints for cuBLASLt algorithm heuristics, with proper C++ wrapper for ease of use.

  4. Workspace Management: Separate setup workspace (pointer arrays) and cuBLAS workspace, with proper alignment and size validation.

  5. API Design: Clear separation between C API (nvte_grouped_gemm) and C++ wrappers (GroupedMatmulConfigWrapper), with comprehensive documentation.

Validation & Testing

  • Comprehensive test suite (test_grouped_gemm.cu) compares against nvte_multi_tensor_gemm baseline
  • Tests all data types: FP8, BF16 with various shape configurations (uniform, varying first/last dimensions)
  • All transpose combinations and null C (beta=0) cases covered
  • Proper architectural checks for Blackwell+ and cuBLAS 13.1+ requirements

Code Quality

  • Proper input validation with detailed error messages
  • RAII memory management throughout (CudaPtr, GroupedTensorHandle, NVIDIA's custom deleters)
  • Helper functions (TensorShapeInfo, GroupedOperandSelection) encapsulate complex logic
  • Comments explain non-obvious decisions (e.g., FP8 TN-layout handling, M/N/K computation semantics)
  • Dimension computation is semantically correct for both transposed and non-transposed cases

No Issues Found

The implementation is production-ready with:

  • Correct attr_sizes computation using decltype(...)::value_type
  • Proper SM100/100a/103a detection (sm_arch >= 100)
  • Safe nullptr handling for optional C tensor (falls back to D)
  • Correct workspace layout and alignment

Confidence Score: 5/5

  • This PR is safe to merge with minimal risk. Implementation is well-designed, thoroughly tested, and handles edge cases properly.
  • Score reflects comprehensive implementation of a complex feature with excellent code quality. All 12 files reviewed showed proper error handling, memory management, and API consistency. Dimension computation logic is correct, FP8 constraints are properly enforced, and test coverage is thorough including edge cases (null C, varying shapes, all transpose combinations, dual data types). Previous review comments have been addressed. No critical, high, or medium severity issues identified.
  • No files require special attention. All implementations follow established patterns and have proper validation.

Important Files Changed

File Analysis

Filename - Score - Overview
transformer_engine/common/gemm/cublaslt_grouped_gemm.cu - 5/5 - Core implementation of grouped GEMM. Bridges NVTEGroupedTensor format to cuBLASLt requirements via GPU setup kernel. Properly handles FP8 TN-layout constraints, varying shapes, and column-wise data fallback. M/N/K dimension computation is correct. C=nullptr fallback to D is properly handled. No critical issues found.
transformer_engine/common/gemm/config.h - 5/5 - GroupedMatmulConfig structure with optional avg_m/avg_n/avg_k hints. Correctly uses sizeof(decltype(avg_m)::value_type) = sizeof(int64_t) = 8 bytes, matching the memcpy operations in config.cpp. No issues.
transformer_engine/common/gemm/config.cpp - 5/5 - Configuration accessors for grouped matmul. Correctly extracts int64_t values from std::optional and performs memcpy with correct sizes. Matches attr_sizes from config.h. No issues found.
transformer_engine/common/include/transformer_engine/gemm.h - 5/5 - API declarations and C++ wrapper classes. Documentation clearly states Blackwell (SM100) + cuBLAS 13.1+ requirements. GroupedMatmulConfigWrapper provides clean C++ interface for setting avg_m/avg_n/avg_k hints. Comprehensive documentation. No issues.
tests/cpp/operator/test_grouped_gemm.cu - 5/5 - Comprehensive test suite comparing grouped_gemm against multi_tensor_gemm baseline. Tests FP8/BF16, all shape variations (uniform, varying first/last/all), transpose combinations, and null C case. Uses cudaStreamDefault (passed as 0), which is semantically correct. No issues found.
tests/cpp/test_common.cu - 5/5 - build_grouped_tensor implementation with random padding for alignment testing. Correctly computes offsets, logical shapes, and sets grouped tensor parameters. Allocates and copies FP8 columnwise data appropriately. No issues found.

Sequence Diagram

sequenceDiagram
    participant User
    participant nvte_grouped_gemm as nvte_grouped_gemm API
    participant setup_kernel as setup_grouped_gemm_kernel
    participant cublaslt as cuBLASLt
    
    User->>nvte_grouped_gemm: Call with GroupedTensors A, B, C, D
    nvte_grouped_gemm->>nvte_grouped_gemm: Validate inputs (tensors, shapes, dtypes)
    nvte_grouped_gemm->>nvte_grouped_gemm: Select operand storage (row/col-wise)
    nvte_grouped_gemm->>nvte_grouped_gemm: Handle FP8 TN-only constraint on Hopper
    nvte_grouped_gemm->>nvte_grouped_gemm: Allocate/validate workspaces
    nvte_grouped_gemm->>setup_kernel: Launch setup kernel (GPU)
    setup_kernel->>setup_kernel: Compute per-matrix pointers from offsets
    setup_kernel->>setup_kernel: Compute M/N/K from tensor shapes
    setup_kernel->>setup_kernel: Fill pointer arrays and dimension arrays
    nvte_grouped_gemm->>nvte_grouped_gemm: Initialize cuBLASLt descriptors
    nvte_grouped_gemm->>nvte_grouped_gemm: Set FP8 scale pointers (if needed)
    nvte_grouped_gemm->>nvte_grouped_gemm: Select algorithm via heuristics
    nvte_grouped_gemm->>cublaslt: Execute grouped GEMM
    cublaslt->>cublaslt: Perform batched matrix multiplication
    cublaslt-->>User: Return with results in D

greptile-apps bot (Contributor) left a comment

Greptile Overview

Greptile Summary

This PR adds nvte_grouped_gemm API using cuBLASLt grouped matmul for batched GEMM on tensors with varying shapes, requiring Blackwell (SM100+) and cuBLAS 13.1+.

Implementation Overview:

  • Core GEMM flow: GPU kernel (setup_grouped_gemm_kernel) converts NVTEGroupedTensor format (contiguous buffer + offsets) to cuBLAS requirements (pointer arrays + per-matrix M/N/K dimensions)
  • FP8 support: Handles scale_inv pointers and includes TN layout selection logic for Hopper compatibility
  • Configuration: GroupedMatmulConfig with optional avg_m/avg_n/avg_k hints for algorithm heuristics
  • Testing: Comprehensive tests comparing against nvte_multi_tensor_gemm across FP8/BF16, various shapes, and transpose configurations
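The core conversion step can be illustrated with a host-side sketch. This is a hypothetical, simplified analogue of what setup_grouped_gemm_kernel does on the GPU (the function name `build_pointer_array` and its signature are illustrative, not from the PR): turn a grouped tensor's contiguous buffer plus per-matrix element offsets into the per-matrix pointer array cuBLASLt grouped matmul expects.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Host-side sketch (illustrative only): cuBLASLt grouped matmul consumes an
// array of device pointers, one per matrix. The real kernel computes these on
// the GPU from the grouped tensor's base pointer and offset metadata.
std::vector<const float*> build_pointer_array(const float* base,
                                              const std::vector<size_t>& offsets,
                                              size_t num_matrices) {
  std::vector<const float*> ptrs(num_matrices);
  for (size_t i = 0; i < num_matrices; ++i) {
    ptrs[i] = base + offsets[i];  // pointer arithmetic in elements
  }
  return ptrs;
}
```

Doing this on the GPU avoids a host round-trip on the offset metadata, which matters when the offsets themselves live in device memory.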

Key Finding:
The setup kernel contains a logic bug in the FP8 TN layout forcing path (lines 453-458) where dimensions are computed incorrectly when columnwise data is selected. However, this code path is unreachable in practice because grouped GEMM requires SM100+ (Blackwell), which supports non-TN FP8 layouts, so the TN-forcing logic never executes. While not a runtime issue, this creates technical debt if requirements change.

Code Quality:

  • Clean separation of concerns with helper functions
  • Proper error checking and validation
  • Good test coverage for supported configurations
  • RAII memory management in test infrastructure

Confidence Score: 4/5

  • Safe to merge - core functionality is correct and well-tested for the supported SM100+ configuration
  • The implementation is sound for its intended SM100+ target. One logic bug exists in unreachable TN-forcing code (P2 severity) that only affects hypothetical Hopper support. All tested paths work correctly. Strong test coverage validates the main use cases.
  • transformer_engine/common/gemm/cublaslt_grouped_gemm.cu - contains unreachable but incorrect FP8 TN layout handling code

Important Files Changed

File Analysis

| Filename | Score | Overview |
| --- | --- | --- |
| transformer_engine/common/gemm/cublaslt_grouped_gemm.cu | 3/5 | New file implementing nvte_grouped_gemm with cuBLASLt. Contains setup kernel for pointer array generation, FP8 scale handling, and operand selection logic. Has untested FP8 TN layout forcing code that would be incorrect if used on Hopper (but SM100+ requirement prevents this). |
| transformer_engine/common/gemm/config.h | 5/5 | Adds GroupedMatmulConfig struct with optional avg_m/n/k hints and sm_count. attr_sizes correctly uses sizeof(value_type) for std::optional fields. |
| transformer_engine/common/gemm/config.cpp | 5/5 | Implements create/get/set/destroy functions for GroupedMatmulConfig. Properly handles std::optional by transferring value_type (int64_t) with value_or(0) for gets and direct assignment for sets. |
| tests/cpp/operator/test_grouped_gemm.cu | 4/5 | Comprehensive tests comparing nvte_grouped_gemm against nvte_multi_tensor_gemm baseline. Tests FP8/BF16, various transpose configs, and shape variations. Only runs on SM100+ which has non-TN FP8 support. |
| tests/cpp/test_common.cu | 5/5 | Adds build_grouped_tensor helper with RAII memory management. Correctly handles varying shapes, random padding for alignment testing, and creates both rowwise and columnwise data for FP8. |
| tests/cpp/test_common.h | 5/5 | Adds GroupedBuffers struct and build_grouped_tensor declaration. Clean RAII design with CudaPtr and GroupedTensorHandle wrappers. |
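The std::optional transfer pattern described for config.cpp can be sketched in isolation. This is a minimal, hypothetical mirror of the described get/set behavior (the struct and function names here are illustrative, not the PR's actual API): getters return value_or(0) so an unset hint reads as 0, while setters assign directly so any set value marks the optional as engaged.

```cpp
#include <cstdint>
#include <optional>

// Illustrative stand-in for the config struct described in the review;
// only the optional-hint transfer pattern is shown.
struct GroupedMatmulConfigSketch {
  std::optional<int64_t> avg_m;  // unset by default
};

// Get: an unset hint is reported as 0 rather than failing.
int64_t get_avg_m(const GroupedMatmulConfigSketch& cfg) {
  return cfg.avg_m.value_or(0);
}

// Set: direct assignment engages the optional.
void set_avg_m(GroupedMatmulConfigSketch& cfg, int64_t v) {
  cfg.avg_m = v;
}
```

One consequence of this convention is that a hint explicitly set to 0 is indistinguishable from an unset hint on the get path, which is presumably acceptable since 0 is not a meaningful average dimension.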

Sequence Diagram

```mermaid
sequenceDiagram
    participant User
    participant nvte_grouped_gemm
    participant select_grouped_operand
    participant setup_kernel
    participant cuBLASLt

    User->>nvte_grouped_gemm: Call with A, B, C, D tensors
    nvte_grouped_gemm->>nvte_grouped_gemm: Validate inputs (SM100+, cuBLAS 13.1+)
    nvte_grouped_gemm->>select_grouped_operand: Select A operand (row/col-wise)
    select_grouped_operand-->>nvte_grouped_gemm: A_sel (adjusted trans, dptr, dtype)
    nvte_grouped_gemm->>select_grouped_operand: Select B operand (row/col-wise)
    select_grouped_operand-->>nvte_grouped_gemm: B_sel (adjusted trans, dptr, dtype)
    nvte_grouped_gemm->>nvte_grouped_gemm: Allocate setup + cuBLAS workspaces
    nvte_grouped_gemm->>setup_kernel: Launch GPU kernel
    setup_kernel->>setup_kernel: Compute pointer arrays (A/B/C/D_ptrs)
    setup_kernel->>setup_kernel: Compute M/N/K per matrix
    setup_kernel->>setup_kernel: Setup alpha/beta pointers
    setup_kernel-->>nvte_grouped_gemm: Populated workspace arrays
    nvte_grouped_gemm->>nvte_grouped_gemm: Init matrix layouts (descA/B/C/D)
    nvte_grouped_gemm->>nvte_grouped_gemm: Init matmul descriptor with ops
    nvte_grouped_gemm->>nvte_grouped_gemm: Set FP8 scale pointers if needed
    nvte_grouped_gemm->>nvte_grouped_gemm: Select algorithm via heuristics
    nvte_grouped_gemm->>cuBLASLt: cublasLtMatmul
    cuBLASLt-->>nvte_grouped_gemm: GEMM results in D
    nvte_grouped_gemm-->>User: Return
```

Comment on lines +453 to +458

```cpp
// Compute M, N, K dimensions from tensor shapes
// Input A is stored as {K,M} when !transa, {M,K} when transa
// Input B is stored as {N,K} when !transb, {K,N} when transb
M[idx] = static_cast<int>(transa ? a_first : a_last);
K[idx] = static_cast<int>(transa ? a_last : a_first);
N[idx] = static_cast<int>(transb ? b_last : b_first);
```
[P2] Setup kernel uses wrong dimensions when columnwise data is selected. When select_grouped_operand switches to columnwise data for FP8 TN layout (lines 249-255), it flips the transpose flag but the setup kernel still uses first_dims/last_dims which describe the rowwise shape, not the columnwise shape. For example, if rowwise A is {K, M} and we switch to columnwise with trans=true, the kernel computes M[idx]=a_first=K and K[idx]=a_last=M (swapped). This bug is currently unreachable because grouped GEMM requires SM100+ (line 512) where non-TN FP8 is supported, so the TN-forcing code never executes. However, this creates a maintenance risk if the SM requirement is lowered or the pattern is copied elsewhere.

Collaborator Author
Ah, it's a bigger problem. It does not result in any bugs, but I think the code needs to be slightly redesigned.

Development

Successfully merging this pull request may close these issues:

GroupedGemm: FP8 per-tensor via cuBLAS

4 participants